#!/usr/bin/env python
"""
BrowseComp benchmark with proper decryption.

This script runs the BrowseComp benchmark, decrypting each example's question
and answer with the per-example canary field before querying Local Deep Research.

Usage:
    # Install dependencies with PDM
    cd /path/to/local-deep-research
    pdm install

    # Run the script with PDM
    pdm run python examples/benchmarks/run_browsecomp.py --help
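
    # Example: evaluate 5 examples using SearXNG
    pdm run python examples/benchmarks/run_browsecomp.py \
        --examples 5 --iterations 1 --questions 1 --search-tool searxng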
"""

import argparse
import base64
import hashlib
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import Any, Dict

from loguru import logger

from local_deep_research.api import quick_summary
from local_deep_research.benchmarks.datasets import load_dataset
from local_deep_research.benchmarks.graders import grade_results
from local_deep_research.benchmarks.templates import BROWSECOMP_QUERY_TEMPLATE


def derive_key(password: str, length: int) -> bytes:
    """Derive a fixed-length key from the password using SHA256."""
    hasher = hashlib.sha256()
    hasher.update(password.encode())
    key = hasher.digest()
    return key * (length // len(key)) + key[: length % len(key)]


def decrypt(ciphertext_b64: str, password: str) -> str:
    """Decrypt base64-encoded ciphertext with XOR."""
    try:
        encrypted = base64.b64decode(ciphertext_b64)
        key = derive_key(password, len(encrypted))
        decrypted = bytes(a ^ b for a, b in zip(encrypted, key, strict=False))
        return decrypted.decode()
    except Exception as e:
        logger.exception(f"Error decrypting data: {e!s}")
        return f"Error: Could not decrypt data - {e!s}"


def run_browsecomp_with_canary(
    num_examples: int = 5,
    search_iterations: int = 1,
    questions_per_iteration: int = 1,
    search_tool: str = "searxng",
    output_dir: str = "browsecomp_results",
) -> Dict[str, Any]:
    """
    Run the BrowseComp benchmark with proper decryption using the canary field.

    Args:
        num_examples: Number of examples to evaluate
        search_iterations: Number of search iterations per query
        questions_per_iteration: Number of questions per iteration
        search_tool: Search engine to use
        output_dir: Directory to save results

    Returns:
        Dictionary with status, result/evaluation file paths, and
        accuracy/timing metrics
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Load BrowseComp dataset
    dataset = load_dataset(
        dataset_type="browsecomp",
        num_examples=num_examples,
        seed=42,
    )
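    # Each record is assumed to carry base64-encoded "problem" and "answer"
    # fields plus a per-example "canary" string used as the decryption
    # password, e.g. {"id": ..., "problem": "<b64>", "answer": "<b64>",
    # "canary": "..."} (see the .get() calls in the loop below).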

    # Set up output files
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    results_file = str(
        Path(output_dir) / f"browsecomp_{timestamp}_results.jsonl"
    )
    evaluation_file = str(
        Path(output_dir) / f"browsecomp_{timestamp}_evaluation.jsonl"
    )

    # Make sure output files don't exist
    for file in [results_file, evaluation_file]:
        if Path(file).exists():
            os.remove(file)

    # Process each example
    results = []
    total_examples = len(dataset)

    for i, example in enumerate(dataset):
        # Reset per-example state so a failure here cannot pick up values
        # left over from the previous iteration (locals persist across loops)
        decrypted_question = None
        decrypted_answer = None
        start_time = None

        # Decrypt the problem and answer using the canary
        try:
            encrypted_question = example.get("problem", "")
            encrypted_answer = example.get("answer", "")
            canary = example.get("canary", "")

            # Decrypt question and answer
            decrypted_question = decrypt(encrypted_question, canary)
            decrypted_answer = decrypt(encrypted_answer, canary)

            logger.info(
                f"Processing {i + 1}/{total_examples}: {decrypted_question[:50]}..."
            )

            # Format query for BrowseComp
            formatted_query = BROWSECOMP_QUERY_TEMPLATE.format(
                question=decrypted_question
            )

            # Time the search
            start_time = time.time()

            # Get response from LDR
            search_result = quick_summary(
                query=formatted_query,
                iterations=search_iterations,
                questions_per_iteration=questions_per_iteration,
                search_tool=search_tool,
            )

            end_time = time.time()
            processing_time = end_time - start_time

            # Extract response
            response = search_result.get("summary", "")
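
            # The BrowseComp query template is assumed to ask the model to end
            # its answer with lines of the form matched by the regexes below:
            #
            #     Exact Answer: <short answer>
            #     Confidence: <percentage>%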

            # Extract exact answer from the response
            answer_match = re.search(r"Exact Answer:\s*(.*?)(?:\n|$)", response)
            exact_answer = (
                answer_match.group(1).strip() if answer_match else "None"
            )

            # Extract confidence from the response (default to 100% when the
            # model does not state one)
            confidence_match = re.search(r"Confidence:\s*(\d+)%", response)
            confidence = (
                confidence_match.group(1) if confidence_match else "100"
            )

            # Format result
            result = {
                "id": example.get("id", f"example_{i}"),
                "problem": decrypted_question,  # Store decrypted question
                "correct_answer": decrypted_answer,  # Store decrypted answer
                "response": response,
                "extracted_answer": exact_answer,
                "confidence": confidence,
                "processing_time": processing_time,
                "sources": search_result.get("sources", []),
                "search_config": {
                    "iterations": search_iterations,
                    "questions_per_iteration": questions_per_iteration,
                    "search_tool": search_tool,
                },
            }

            # Add to results list
            results.append(result)

            # Write result to file
            with open(results_file, "a") as f:
                f.write(json.dumps(result) + "\n")

        except Exception as e:
            logger.exception(f"Error processing example {i + 1}: {e!s}")

            # Create error result
            error_result = {
                "id": example.get("id", f"example_{i}"),
                "problem": (
                    decrypted_question
                    if "decrypted_question" in locals()
                    else "Error: Could not decrypt problem"
                ),
                "correct_answer": (
                    decrypted_answer
                    if "decrypted_answer" in locals()
                    else "Error: Could not decrypt answer"
                ),
                "error": str(e),
                "processing_time": (
                    time.time() - start_time if "start_time" in locals() else 0
                ),
            }

            # Add to results list
            results.append(error_result)

            # Write error result to file
            with open(results_file, "a") as f:
                f.write(json.dumps(error_result) + "\n")

    logger.info(f"Completed processing {total_examples} examples")

    # Run evaluation
    logger.info("Running automated evaluation...")
    try:
        evaluation_results = grade_results(
            results_file=results_file,
            output_file=evaluation_file,
            dataset_type="browsecomp",
        )
    except Exception as e:
        logger.exception(f"Evaluation failed: {e!s}")
        evaluation_results = []
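
    # grade_results is expected to return one graded record per example, each
    # with an "is_correct" flag (used in the accuracy calculation below).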

    # Calculate basic metrics
    correct_count = sum(
        1 for result in evaluation_results if result.get("is_correct", False)
    )
    accuracy = correct_count / len(results) if results else 0
    avg_time = (
        sum(result.get("processing_time", 0) for result in results)
        / len(results)
        if results
        else 0
    )

    print("\nBrowseComp Benchmark Results:")
    print(f"  Accuracy: {accuracy:.3f}")
    print(f"  Total examples: {total_examples}")
    print(f"  Correct answers: {correct_count}")
    print(f"  Average time: {avg_time:.2f}s")
    print()
    print(f"Report saved to: {evaluation_file}")

    return {
        "status": "complete",
        "dataset_type": "browsecomp",
        "results_path": results_file,
        "evaluation_path": evaluation_file,
        "metrics": {"accuracy": accuracy, "average_processing_time": avg_time},
        "total_examples": total_examples,
        "accuracy": accuracy,
    }


def main():
    """Run the BrowseComp benchmark with command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Run BrowseComp benchmark with proper decryption"
    )
    parser.add_argument(
        "--examples", type=int, default=2, help="Number of examples to run"
    )
    parser.add_argument(
        "--iterations", type=int, default=1, help="Number of search iterations"
    )
    parser.add_argument(
        "--questions", type=int, default=1, help="Questions per iteration"
    )
    parser.add_argument(
        "--search-tool", type=str, default="searxng", help="Search tool to use"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(Path("examples") / "benchmarks" / "results" / "browsecomp"),
        help="Output directory",
    )

    args = parser.parse_args()

    run_browsecomp_with_canary(
        num_examples=args.examples,
        search_iterations=args.iterations,
        questions_per_iteration=args.questions,
        search_tool=args.search_tool,
        output_dir=args.output_dir,
    )

    return 0


if __name__ == "__main__":
    sys.exit(main())
